# -*- coding: utf-8 -*-

import os

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

def fp(p, horizon=None):
    """Return how many times each active arm is pulled in exploration round ``p``.

    The schedule is constant across rounds: ``int(10 * ln(horizon))``.
    ``p`` is currently unused by this schedule.  Returning ``1`` instead
    simulates the centralized setting.

    Args:
        p: Index of the current exploration round (ignored by this schedule).
        horizon: Time horizon used in the log factor; defaults to the
            module-level ``T``.

    Returns:
        The exploration length as an ``int``.
    """
    if horizon is None:
        horizon = T
    expl_len = 10 * np.log(horizon)
    return int(expl_len)


N = 10  # number of independent repetitions

C = 1  # communication loss charged per player at every server aggregation

T = int(1e6)  # time horizon (number of steps per repetition)
sigma = 1/2  # scale used in the elimination confidence bound

K = 10  # number of arms
M = 8  # number of local players
# Per-player Bernoulli means: mu_local[m, k] is the mean of arm k as seen by
# local player m (8 rows = players, 10 columns = arms).
mu_local = np.array([[0.68269, 0.86294, 0.1709 , 0.82458, 0.07878, 0.24676, 0.90536, 0.96767,
  0.59301, 0.61714],
 [0.43568, 0.55027, 0.68145, 0.93006, 0.56557, 0.4975 , 0.81858, 0.11724,
  0.89142, 0.96592],
 [0.9078 , 0.95654, 0.86752, 0.06922, 0.8344 , 0.68833, 0.5249 , 0.70903,
  0.71558, 0.45445],
 [0.78629, 0.68128, 0.91343, 0.6928 , 0.99583, 0.99903, 0.98784, 0.737  ,
  0.32008, 0.71274],
 [0.87981, 0.90911, 0.91894, 0.98812, 0.98992, 0.99487, 0.56838, 0.92776,
  0.83553, 0.74845],
 [0.98665, 0.75137, 0.99478, 0.99778, 0.99889, 0.99994, 0.77171, 0.99999,
  0.97195, 0.99368],
 [0.92108, 0.85349, 0.98298, 0.99244, 0.99661, 0.99857, 0.81323, 0.89631,
  0.99243, 0.79262],
 [0.8    , 0.795  , 0.79   , 0.785  , 0.78   , 0.775  , 0.77   , 0.765  ,
  0.76   , 0.755  ]])

# The global mean of an arm is the average of its local means; the arm that
# maximises it is the benchmark the simulation measures regret against.
mu_global = np.mean(mu_local, axis=0)
best_arm = int(np.argmax(mu_global))
print(best_arm)
def pull_arm(mu):
    """Draw a single Bernoulli reward with success probability ``mu``.

    One uniform variate on [0, 1) is drawn and compared against ``mu``.

    Returns:
        ``1`` if the variate is below ``mu``, otherwise ``0``.
    """
    u = np.random.uniform(0, 1)
    return int(u < mu)


def get_bits(num):
    """Return the number of bits charged for transmitting ``num``.

    The count is ``ceil(1 + log2(num))``, with non-positive ``num`` treated
    as 1, and is never less than one bit.  Without that floor, inputs in
    (0, 0.5] (e.g. sample means) produced 0 or negative counts, which made
    the cumulative communication-bit totals decrease.

    Returns:
        A positive ``int``.
    """
    num = num if num > 0 else 1
    return max(1, int(np.ceil(1 + np.log2(num))))
# ---------------------------------------------------------------------------
# Monte-Carlo simulation of federated successive elimination.
#
# Each repetition: every local player pulls each active arm fp(p) times per
# round; the server averages the local sample means and drops arms whose upper
# confidence bound is below the best active lower bound; once a single arm is
# left, the rest of the horizon is extrapolated with that arm's expected
# reward.  Group regret, per-player regret and cumulative communication
# counters are collected for every repetition and saved to disk.
# ---------------------------------------------------------------------------
comm_c = 0  # running count of player->server uploads over all repetitions

regret_list = []
ind_reg_list = []
comm_times_list = []
comm_bits_list = []
for _rep in tqdm(range(N)):
    t = 1
    p = 0

    active_arm = np.arange(K, dtype=int)
    pull_num = np.zeros([M, K])
    reward_local = np.zeros([M, K])
    reward_t = np.zeros((M, T))  # per-player cumulative reward
    reward_global = np.zeros(T)  # group cumulative reward
    opt_rw = np.zeros((M, T))  # per-player cumulative reward of the best arm
    optimal_reward = np.zeros(T)  # group cumulative reward of the best arm
    comm_times = np.zeros(T)  # per-step message counts (made cumulative below)
    comm_bits = np.zeros(T)  # per-step bit counts (made cumulative below)

    # Pre-draw every Bernoulli reward with the same rule as pull_arm()
    # (uniform < mu), one vectorised call per (player, arm) row instead of T
    # Python-level calls; draws are taken in the same (player, arm, step) order.
    data_local = np.zeros([M, K, T])  # M*K*T
    for j in range(M):
        for i in range(K):
            data_local[j, i] = np.random.uniform(0, 1, T) < mu_local[j, i]

    data_global = np.zeros([K, T])  # K*T
    for i in range(K):
        data_global[i] = np.random.uniform(0, 1, T) < mu_global[i]

    optimal_index = best_arm
    while t < T:
        # ---- local players: pull every active arm expl_len times ----
        if len(active_arm) > 1:
            expl_len = fp(p)
            p += 1
            for k in active_arm:
                for _ in range(min(T - t, expl_len)):
                    reward_local[:, k] += data_local[:, k, t]
                    reward_t[:, t] = reward_t[:, t - 1] + data_global[k, t]
                    opt_rw[:, t] = opt_rw[:, t - 1] + data_global[optimal_index, t]
                    pull_num[:, k] += 1
                    # One message per player per step, the same rate as the
                    # single-arm phase below (counted once per step, not once
                    # per player per step).
                    comm_times[t] += M
                    comm_bits[t] += get_bits(reward_global[-1]) * M * len(reward_global)
                    reward_global[t] = reward_global[t - 1] + M * data_global[k, t]
                    optimal_reward[t] = (
                        optimal_reward[t - 1] + M * data_global[optimal_index, t]
                    )
                    t += 1
            mu_local_sample = reward_local / pull_num
            if t >= T:
                break
            comm_times[t] += 1
            comm_bits[t] += get_bits(reward_global[-1]) * len(reward_global)

        # ---- a single arm survives: extrapolate with its expected reward ----
        if len(active_arm) == 1:
            # Step t is the first remaining pull, so the offsets run 1..T-t;
            # starting at 0 would drop one step of reward.
            steps = np.arange(1, T - t + 1)
            arm_mu = mu_global[active_arm[0]]
            best_mu = mu_global[optimal_index]
            reward_global[t:] = reward_global[t - 1] + steps * M * arm_mu
            optimal_reward[t:] = optimal_reward[t - 1] + steps * M * best_mu
            reward_t[:, t:] = reward_t[:, t - 1:t] + steps * arm_mu
            opt_rw[:, t:] = opt_rw[:, t - 1:t] + steps * best_mu
            comm_times[t] += M * (T - t)
            comm_bits[t] += get_bits(reward_global[-1]) * M * len(reward_global) * (T - t)
            break

        # ---- global server: aggregate local means and eliminate arms ----
        if len(active_arm) > 1:
            comm_c += M
            reward_global[t - 1] -= C * M  # comment this line out to ignore communication loss
            comm_times[t] += M
            comm_bits[t] += get_bits(mu_local_sample[-1, -1]) * M * len(mu_local_sample)
            mu_global_sample = np.mean(mu_local_sample, axis=0)
            # The constants are tuned from the original ones in the paper to
            # get better performance.  All active arms have been pulled the
            # same number of times, so one bound serves all of them.
            conf_bnd = np.sqrt(
                4 * sigma**2 * np.log(T) / (M * pull_num[0, active_arm[0]])
            )
            # Threshold from the active arms only: eliminated arms keep stale
            # estimates built from fewer pulls and must not drive elimination.
            active_mu = mu_global_sample[active_arm]
            elm_max = np.nanmax(active_mu) - conf_bnd
            active_arm = active_arm[~(active_mu + conf_bnd < elm_max)]

    # Per-step increments -> cumulative totals.
    comm_times = np.cumsum(comm_times)
    comm_bits = np.cumsum(comm_bits)
    regret_list.append(optimal_reward - reward_global)
    ind_reg_list.append(opt_rw - reward_t)
    comm_times_list.append(comm_times)
    comm_bits_list.append(comm_bits)

regret_list = np.array(regret_list)
comm_times_list = np.array(comm_times_list)
comm_bits_list = np.array(comm_bits_list)

# np.save does not expand "~", so resolve the home directory explicitly and
# make sure the output folder exists before writing.
out_dir = os.path.expanduser('~/var_delta/data/feducb')
os.makedirs(out_dir, exist_ok=True)
np.save(os.path.join(out_dir, 'group_regret_list_mu3.npy'), regret_list)
np.save(os.path.join(out_dir, 'comm_times_list_mu3.npy'), comm_times_list)
np.save(os.path.join(out_dir, 'comm_bits_list_mu3.npy'), comm_bits_list)
np.save(os.path.join(out_dir, 'ind_reg_list_mu3.npy'), ind_reg_list)